Multi-level example

Published

July 29, 2024

Modified

July 11, 2024

Assume you have 5-10 treatment arms to which patients are randomised and that, within each treatment arm, participants are a mix with respect to some characteristic that could influence the response. We are interested in the overall treatment effect and in the heterogeneity due to the subgroup.

Within this setting there are multiple levels of variation. There is the between treatment arm variation that characterises how different the treatment arms are. There is also the within treatment arm variation due to the subgroup.

We could adopt a range of assumptions to model the responses for each combination of treatment and subgroup.

  1. Model the responses for each treatment arm by subgroup combination independently; no information is shared between any of the combinations. This approach will recover the observed point estimates and could be achieved using a logistic regression to estimate each combination’s mean response.
  2. Model the treatment groups independently, but estimate the variation within each treatment arm due to subgroup membership by sharing a common variance parameter across all the treatment arms. This will reflect the observed mean response in each treatment arm but shrink the subgroup variation towards these means. It will only do this if there is sufficient data to inform the relevant variance estimate.
  3. Model both the between treatment arm variation and the within treatment arm variation by partitioning the total variation. This will tend to shrink the treatment arm means towards an overall mean and the subgroup estimates towards each treatment arm mean. The within group variation could be modelled for each treatment arm independently or be shared across all treatments.

There are also variations on these themes. All of the variance partition approaches require sufficient groups to inform the variance parameters.
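To see why the number of groups matters, here is a minimal sketch in base R (not part of the original analysis). It repeatedly draws J group means from a normal distribution with a known standard deviation and looks at how variable the sample standard deviation is. The values of `s_true`, `n_rep` and the choices of `J` are illustrative assumptions.

```r
# Sketch: how well can a standard deviation be learnt from J group means?
# All settings are illustrative assumptions, not values from the analysis.
set.seed(123)
s_true <- 0.5   # true SD of the group means
n_rep  <- 5000  # number of simulated replicates per setting

sd_spread <- sapply(c(2, 8, 50), function(J) {
  # sample SD of J group means, repeated n_rep times
  s_hat <- replicate(n_rep, sd(rnorm(J, 0, s_true)))
  # width of the central 90% interval of the estimates
  unname(diff(quantile(s_hat, c(0.05, 0.95))))
})
names(sd_spread) <- c("J=2", "J=8", "J=50")
round(sd_spread, 2)
```

With only two groups the estimated standard deviation is spread widely around the true value, so a model with a variance parameter will tend to lean on its prior in that case.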

Parameter specification/generation

Assume that the true treatment group means are normally distributed around some non-zero mean with standard deviation \(s\) and that the subgroup means are normally distributed around each treatment group mean with a common standard deviation \(s_j\).
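Written out, the assumed data generating process is as follows, where \(j\) indexes treatment arms, \(k\) indexes subgroups and \(y_i\) is the binary response for participant \(i\). The normal distributions are parameterised by their standard deviation. The Bernoulli likelihood with a logit link is taken from the data generation function used later in this section; the index notation \(j[i]\), \(k[i]\) and the symbols \(J\), \(K\), \(N\) are introduced here for illustration.

```latex
\begin{aligned}
\mu_j &\sim \text{Normal}(\mu, s), & j &= 1, \dots, J \\
\mu_{jk} &\sim \text{Normal}(\mu_j, s_j), & k &= 1, \dots, K \\
y_i &\sim \text{Bernoulli}\left(\text{logit}^{-1}(\mu_{j[i]k[i]})\right), & i &= 1, \dots, N
\end{aligned}
```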

Parameter specification
library(data.table)  # provides CJ() and the := operator used below

get_par <- function(
    n_grp = 4, n_trt = 9,
    mu = 1,
    s = 0.1, s_j = 0.3
    ){
  
  l <- list()
  l$n_grp <- n_grp
  l$n_trt <- n_trt

  # overall mean effect across all intervention types
  l$mu <- mu
  # between intervention type variation
  l$s <- s
  # intervention type specific mean
  l$mu_j <- l$mu + rnorm(n_trt, 0, l$s)
  # within intervention variation attributable to group membership
  l$s_j <- s_j

  # trt x group effects
  l$mu_j_k <- do.call(rbind, lapply(seq_along(l$mu_j), function(i){
    rnorm(n_grp, l$mu_j[i], l$s_j)
  }))
  colnames(l$mu_j_k) <- paste0("strata", 1:ncol(l$mu_j_k))
  rownames(l$mu_j_k) <- paste0(1:nrow(l$mu_j_k))
  
  l$d_par <- CJ(
    j = factor(1:l$n_trt, levels = 1:l$n_trt),
    k = factor(1:l$n_grp, levels = 1:l$n_grp)
  )
  l$d_par[, mu := l$mu]
  l$d_par[, mu_j := l$mu_j[j]]
  l$d_par[, mu_j_k := l$mu_j_k[cbind(j,k)]]
  
  l
}
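The function above looks up cell means with `l$mu_j_k[cbind(j,k)]`, that is, by indexing a matrix with a two-column matrix of (row, column) pairs. The following base R sketch shows that indexing on a small made-up matrix; the object names `m`, `j` and `k` are for illustration only. When `j` and `k` are factors, `cbind()` replaces them by their internal integer codes, which here coincide with the treatment and subgroup numbers.

```r
# Sketch: indexing a matrix with a two-column (row, column) matrix.
# The values are made up for illustration.
m <- matrix(1:6, nrow = 3, ncol = 2)  # rows = treatment j, columns = subgroup k
j <- c(1, 3, 2)
k <- c(2, 1, 2)

# picks out m[1, 2], m[3, 1] and m[2, 2], one element per (j, k) pair
m[cbind(j, k)]
# 4 3 5
```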

Any single data set will not allow us to recover the parameters exactly, but the differences between the estimates from the various modelling assumptions are informative as to the general patterns that arise.

Data generation function
get_data <- function(
    N = 2000, 
    par = NULL,
    ff = function(par, j, k){
      
      m1 <- cbind(j, k)
      eta = par$mu_j_k[m1] 
      eta
      
    }){
  
  # strata
  d <- data.table()
  
  # intervention - even allocation
  d[, j := sample(1:par$n_trt, size = N, replace = T)]
  # table(d$j)
  # uneven distribution of groups in the pop
  z <- rnorm(par$n_grp, 0, 0.5)
  d[, k := sample(1:par$n_grp, size = N, replace = T, prob = exp(z)/sum(exp(z)))]
  # d[, k := sample(1:par$n_grp, size = N, replace = T)]
  # table(d$j, d$k)
  
  d[, eta := ff(par, j, k)]
  
  d[, y := rbinom(.N, 1, plogis(eta))]
  
  d  
}
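The data generation function relies on two transformations: the inverse logit (`plogis`), which maps log-odds to a probability, and a softmax, which turns unconstrained values into subgroup probabilities. A minimal base R sketch of both follows; the values of `eta` and the seed are illustrative assumptions.

```r
# Sketch: the two transformations used when simulating the data.

# 1. Inverse logit: log-odds -> probability
eta <- c(0, 1, 1.5)          # illustrative log-odds
p <- plogis(eta)             # same as 1 / (1 + exp(-eta))
round(p, 3)

# 2. Softmax: unconstrained values -> probabilities that sum to one
set.seed(10)
z <- rnorm(5, 0, 0.5)        # one value per subgroup
w <- exp(z) / sum(exp(z))    # uneven subgroup probabilities
sum(w)
```

A log-odds of 1, the overall mean used in this section, corresponds to a response probability of about 0.73.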

Generate data assuming the parameters below, with the underlying truth shown in Figure 1. The dashed line shows the overall mean response, the crosses show the treatment arm means and the points show the subgroup heterogeneity around the treatment arm means. Because s = 0 in this scenario, every treatment arm mean is equal to the overall mean.

True treatment arm by subgroup mean response
library(ggplot2)
library(ggrepel)  # provides geom_text_repel()

set.seed(1)
par <- get_par(n_grp = 5, n_trt = 8, mu = 1, s = 0.0, s_j = 0.5)
d <- get_data(N = 3000, par)


d_fig_2 <- unique(par$d_par[, .(mu_j, j)])
d_fig_2[1, label := "Treatment mean"]
       

p_fig <- ggplot(par$d_par, aes(x = j, y = mu_j_k, col = k)) +
  geom_point() +
  geom_hline(yintercept = par$mu, lwd = 0.25, lty = 2) +
  geom_text_repel(
    data = data.table(
      x = 2.5, y = par$mu, label = "Overall mean"
    ),
    aes(x = x, y = y, label = label),
                  inherit.aes = F,
                  nudge_x = 0.4,
                  nudge_y = 0.1,
                  segment.curvature = -0.1,
                  segment.ncp = 3,
                  segment.angle = 20,
                  box.padding = 2, max.overlaps = Inf, col = 1) +
  geom_text_repel(data = d_fig_2,
                  aes(x = j, y = mu_j, label = label), 
                  inherit.aes = F,
                  nudge_x = 0.4,
                  nudge_y = -0.05,
                  segment.curvature = -0.1,
                  segment.ncp = 3,
                  segment.angle = 20,
                  box.padding = 2, max.overlaps = Inf, col = 1) +
  geom_point(data = d_fig_2,
             aes(x = j, y = mu_j),
             inherit.aes = F, pch = 3, size = 3) +
  scale_x_discrete("Treatment type") +
  scale_y_continuous("Log-odds of success", 
                     breaks = seq(
                       round(min(par$d_par$mu_j_k), 1), 
                       round(max(par$d_par$mu_j_k), 1), 
                       by = 0.1)) +
  scale_color_discrete("Subgroup membership")

suppressWarnings(print(p_fig))
Figure 1: True treatment arm by subgroup mean response

Parameter estimation

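The first model (no pooling) produces independent estimates of the treatment arm by subgroup means. A maximum likelihood fit with bootstrapped intervals is included as a reference point.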
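As a small check of the claim that an independent (saturated) model recovers the observed point estimates, the sketch below fits `glm(y ~ j * k)` to made-up data and compares the fitted log-odds with the observed log-odds in each cell. The data frame `d0`, the cell sizes and the values in `eta_true` are illustrative assumptions and are unrelated to the simulated data used in this section.

```r
# Sketch: a saturated logistic regression reproduces the observed cell log-odds.
# Data are made up for illustration.
set.seed(2)
d0 <- expand.grid(j = factor(1:3), k = factor(1:2))
d0 <- d0[rep(seq_len(nrow(d0)), each = 200), ]      # 200 participants per cell
eta_true <- c(0.5, 1.0, 1.5, 0.8, 1.2, 0.2)         # one log-odds per cell
cell <- as.integer(interaction(d0$j, d0$k))         # cell index 1..6
d0$y <- rbinom(nrow(d0), 1, plogis(eta_true[cell]))

fit <- glm(y ~ j * k, data = d0, family = binomial)

obs <- qlogis(tapply(d0$y, cell, mean))             # observed log-odds per cell
est <- tapply(predict(fit), cell, mean)             # fitted log-odds per cell
all.equal(as.numeric(obs), as.numeric(est), tolerance = 1e-5)
```

The comparison holds only when every cell contains both successes and failures; otherwise the observed log-odds are infinite and the maximum likelihood estimate does not exist.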

Fit model to simulated data
# mle - reference point
d_lm <- copy(d)
d_lm[, `:=`(j = factor(j), k = factor(k))]
f0 <- glm(y ~ j*k, data = d_lm, family = binomial)
X <- model.matrix(f0)
# CI
n_sim <- 1000
d_lm_j <- matrix(NA, nrow = par$n_trt, ncol = n_sim)
d_lm_j_k <- matrix(NA, nrow = par$n_trt * par$n_grp, ncol = n_sim)
for(i in 1:n_sim){
  ix <- sort(sample(1:nrow(d_lm), replace = T))
  f_boot <- glm(y ~ j*k, data = d_lm[ix], family = binomial)
  d_tmp_p <- cbind(
    d_lm[ix],
    mu = predict(f_boot)
  )
  d_lm_j[, i] <- d_tmp_p[, .(mean = mean(mu)), keyby = j][, mean]
  d_lm_j_k[, i] <- d_tmp_p[, .(mean = mean(mu)), keyby = .(k, j)][, mean]
}
# bootstrapped intervals for means on j and within group means
d_lm_j <- data.table(d_lm_j)
d_lm_j[, j := 1:.N]
d_lm_j <- melt(d_lm_j, id.vars = "j")
d_lm_j[, j := factor(j)]
# d_lm_j[, .(q5 = quantile(value, prob =0.05), 
#            q95 = quantile(value, prob = 0.95)), keyby = j]

d_lm_j_k <- data.table(d_lm_j_k)
d_lm_j_k <- cbind(CJ(k = 1:par$n_grp, j = 1:par$n_trt), d_lm_j_k)
d_lm_j_k <- melt(d_lm_j_k, id.vars = c("j", "k"))
d_lm_j_k[, `:=`(j = factor(j), k = factor(k))]
# d_lm_j_k[, .(q5 = quantile(value, prob =0.05), 
#            q95 = quantile(value, prob = 0.95)), keyby = .(k,j)]

d_lm[, eta_hat := predict(f0)]

# bayes
m1 <- cmdstanr::cmdstan_model("stan/mlm-ex-01.stan")
m2 <- cmdstanr::cmdstan_model("stan/mlm-ex-02.stan")
m3 <- cmdstanr::cmdstan_model("stan/mlm-ex-03.stan")

ld <- list(
  N = nrow(d), 
  y = d$y, 
  J = length(unique(d$j)), K = length(unique(d$k)),
  j = d$j, # intervention
  k = d$k, # subgroup
  P = ncol(X),
  X = X,
  s = 3,
  s_j = 3, # for the indep means offsets
  prior_only = 0
)

f1 <- m1$sample(
    ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 2 finished in 6.5 seconds.
Chain 1 finished in 7.4 seconds.

Both chains finished successfully.
Mean chain execution time: 7.0 seconds.
Total execution time: 7.5 seconds.
Fit model to simulated data
f2 <- m2$sample(
    ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 2.6 seconds.
Chain 2 finished in 2.6 seconds.

Both chains finished successfully.
Mean chain execution time: 2.6 seconds.
Total execution time: 2.7 seconds.
Fit model to simulated data
f3 <- m3$sample(
    ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 2 finished in 2.4 seconds.
Chain 1 finished in 2.6 seconds.

Both chains finished successfully.
Mean chain execution time: 2.5 seconds.
Total execution time: 2.7 seconds.
Warning: 3 of 2000 (0.0%) transitions ended with a divergence.
See https://mc-stan.org/misc/warnings for details.

Figure 2 shows the estimated treatment arm means and Figure 3 shows the estimated treatment arm by subgroup means for each of the combinations.

Patterns:

  1. The observed and modelled overall mean are very similar.
  2. The MLE and the no-pooling models produce similar estimates; both tend to reflect the observed response.
  3. The treatment arm means estimated using MLE and no-pooling are more influenced by outliers than the treatment arm means estimated under the partially pooled models.
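The reduced influence of outliers under partial pooling can be illustrated with precision-weighted shrinkage under a normal approximation. This is a minimal sketch and is not the Stan model used in this section; the estimates `y_bar`, standard errors `se`, overall mean `mu` and between-group standard deviation `tau` are all made-up values.

```r
# Sketch: precision-weighted shrinkage under a normal approximation.
# This is NOT the Stan model used in this section; all numbers are made up.
y_bar <- c(0.4, 1.0, 1.9)    # observed group estimates (log-odds)
se    <- c(0.5, 0.2, 0.5)    # their standard errors
mu    <- 1.0                 # assumed overall mean
tau   <- 0.3                 # assumed SD of the true group means around mu

# weight given to the observed estimate; close to 1 means little shrinkage
w <- (1 / se^2) / (1 / se^2 + 1 / tau^2)
theta_hat <- w * y_bar + (1 - w) * mu
round(cbind(y_bar, se, w, theta_hat), 2)
```

Groups with large standard errors get a small weight and are pulled strongly towards the overall mean, whereas precisely estimated groups stay close to their observed value.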
Parameter estimates vs true values
d_1 <- data.table(t(f1$draws(variables = "eta", format = "matrix")))
d_1 <- cbind(d, d_1)
d_1 <- melt(d_1, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_2 <- data.table(t(f2$draws(variables = "eta", format = "matrix")))
d_2 <- cbind(d, d_2)
d_2 <- melt(d_2, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_3 <- data.table(t(f3$draws(variables = "eta", format = "matrix")))
d_3 <- cbind(d, d_3)
d_3 <- melt(d_3, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_mu_j <- rbind(
  d_1[, .(mu = mean(value)), keyby = .(j, i_draw)][
  , .(desc = "no pooling", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = j],
  d_2[, .(mu = mean(value)), keyby = .(j, i_draw)][
  , .(desc = "partial pool (trt)", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = j],
  d_3[, .(mu = mean(value)), keyby = .(j, i_draw)][
  , .(desc = "partial pool (trt+subgrp)", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = j]
)

# d_mu_j <- rbind(
#   data.table(f2$summary(variables = c(
#     "mu_j"
#     )))[, .(desc = "partial pool (trt)", variable, mean, q5, q95)],
#   data.table(f3$summary(variables = c(
#     "mu_j"
#     )))[, .(desc = "partial pool (trt+subgrp)", variable, mean, q5, q95)], 
#   fill = T)
# d_mu_j[, j := gsub("mu_j[", "", variable, fixed = T)]
# d_mu_j[, j := gsub("]", "", j, fixed = T)]

# d_mu_j <- rbind(
#   d_mu_j,
#   d_1[, .(mu = mean(value)), keyby = .(j, i_draw)][
#   , .(desc = "no pooling", 
#       mean = mean(mu), 
#       q5 = quantile(mu, prob = 0.05), 
#       q95 = quantile(mu, prob = 0.95)), keyby = j],
#   fill = T)
  

# MLE
d_mle <- cbind(
  desc = "mle", 
  d_lm[, .(mean = mean(eta_hat)), keyby = .(j)])
d_mle <- merge(
  d_mle, 
  d_lm_j[, .(q5 = quantile(value, prob =0.05),
           q95 = quantile(value, prob = 0.95)), keyby = j], 
  by = "j"
)

d_mu_j <- rbind(d_mu_j, d_mle, fill = T)

# Observed
# Due to the imbalance between the subgroups, need to weight the
# contributions to align (approx) with mle.
# Easy here because only one other variable that we need to average
# over.
d_obs <- merge(
  d[, .(N_j = .N), keyby = .(j)],
  d[, .(mu = qlogis(mean(y)), .N), keyby = .(j, k)],
  by = "j"
)
d_mu_j <- rbind(
  d_mu_j,
  d_obs[, .(
    desc = "observed", 
    mean = sum(mu * N / N_j)), keyby = j]
  , fill = T
  )


d_mu_j[, desc := factor(desc, levels = c(
  "observed", "mle", "no pooling", "partial pool (trt)", "partial pool (trt+subgrp)"
))]

# Grand mean
# d_mu <- rbind(
#   data.table(f2$summary(variables = c(
#     "mu"
#     )))[, .(desc = "partial pool (trt)", variable, mean, q5, q95)],
#   data.table(f3$summary(variables = c(
#     "mu"
#     )))[, .(desc = "partial pool (trt+subgrp)", variable, mean, q5, q95)]
# )
  

# Same deal with grand mean, you need to weight the individual level
# contributions in order to align with the model estimates

# d_mu <- rbind(
#   d_mu, 
#   d[, .(mu = qlogis(mean(y)), 
#         w = .N/nrow(d)), keyby = .(j, k)][
#           , .(desc = "observed", mean = sum(mu * w))], 
#   fill = T
# )  

d_mu <- d[, .(mu = qlogis(mean(y)), 
         w = .N/nrow(d)), keyby = .(j, k)][
           , .(desc = "observed", mean = sum(mu * w))]


p_fig <- ggplot(d_mu_j, aes(x = j, y = mean, group = desc, col = desc)) +
  geom_point(position = position_dodge(width = 0.6)) +
  geom_linerange(aes(ymin=q5, ymax=q95), position = position_dodge2(width = 0.6)) +
  geom_hline(data = d_mu,
             aes(yintercept = mean, lty = desc),
             lwd = 0.3, col = 1) +
  scale_x_discrete("Treatment arm") +
  scale_y_continuous("log-odds treatment success", breaks = seq(-3, 3, by = 0.2))  +
  scale_color_discrete("") +
  scale_linetype_discrete("") +
  geom_text(data = d[, .N, keyby = .(j)],
            aes(x = j, y = min(d_mu_j$q5, na.rm = T) - 0.1, 
                label = N), inherit.aes = F) +
  theme(legend.position="bottom", legend.box="vertical", legend.margin=margin())

suppressWarnings(print(p_fig))  
Figure 2: Parameter estimates vs true values
Parameter estimates vs true values
# Posterior
# d_mu_j_k <- rbind(
#   data.table(f2$summary(variables = c(
#     "mu_j_k"
#     )))[, .(desc = "partial pool (trt)", variable, mean, q5, q95)],
#   data.table(f3$summary(variables = c(
#     "mu_j_k"
#     )))[, .(desc = "partial pool (trt+subgrp)", variable, mean, q5, q95)]
# )
# d_mu_j_k[, j := substr(variable, 8, 8)]
# d_mu_j_k[, k := substr(variable, 10, 10)]

# Manually calculate intervals for independent model
d_mu_j_k <- rbind(
  # d_mu_j_k,
  d_1[, .(mu = mean(value)), keyby = .(k, j, i_draw)][
  , .(desc = "no pooling", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = .(k, j)],
  d_2[, .(mu = mean(value)), keyby = .(k, j, i_draw)][
  , .(desc = "partial pool (trt)", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = .(k, j)],
  d_3[, .(mu = mean(value)), keyby = .(k, j, i_draw)][
  , .(desc = "partial pool (trt+subgrp)", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = .(k, j)]
  )

# MLE
d_mle <- cbind(
  desc = "mle", 
  d_lm[, .(mean = mean(eta_hat)), keyby = .(k, j)])
d_mle <- merge(
  d_mle, 
  d_lm_j_k[, .(q5 = quantile(value, prob =0.05),
               q95 = quantile(value, prob = 0.95)), keyby = .(k,j)], 
  by = c("j","k")
)

d_mu_j_k <- rbind(d_mu_j_k, d_mle, fill = T)

# Observed
d_mu_j_k <- rbind(
  d_mu_j_k, 
  d[, .(desc = "observed", mean = qlogis(mean(y))), keyby = .(k, j)],
  fill = T)


d_mu_j_k[, desc := factor(desc, levels = c(
  "observed", "mle", "no pooling", "partial pool (trt)", "partial pool (trt+subgrp)"
))]

# overall mean: observed and modelled (model 3)
d_mu_obs <- d[, .(mu = qlogis(mean(y)))]
d_mu_mod <- data.table(f3$summary(variables = c(
    "mu"
    )))[, .(model = "partial pool (trt+subgrp)", variable, mean, q5, q95)]

p_fig <- ggplot(d_mu_j_k, aes(x = k, y = mean, group = desc, col = desc)) +
  geom_point(position = position_dodge(width = 0.6)) +
  geom_linerange(aes(ymin=q5, ymax=q95), position = position_dodge2(width = 0.6)) +
  scale_x_discrete("Subgroup") +
  scale_y_continuous("log-odds treatment success", breaks = seq(-3, 3, by = 0.2)) +
  geom_hline(data = d_mu_j[desc == "observed"],
             aes(yintercept = mean, group = desc),
             lwd = 0.3, col = 1) +
  geom_text(data = d[, .N, keyby = .(j, k)],
            aes(x = k, y = min(d_mu_j_k$q5, na.rm = T) - 0.1, label = N), 
            inherit.aes = F) +
  facet_wrap(~paste0("Treatment ", j)) 

suppressWarnings(print(p_fig))
Figure 3: Parameter estimates vs true values

The multilevel model (MLM) partitions the variance into a between-treatment part that characterises the variation among the treatment arms and a within-treatment part that characterises the variation due to the subgroups.
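On the log-odds scale, the data generating process described earlier implies that the total variance of the treatment arm by subgroup means is the sum of the between-treatment variance \(s^2\) and the within-treatment variance \(s_j^2\). The base R sketch below checks this by simulation. The numbers of arms and subgroups and the values of `s` and `s_j` are illustrative assumptions, chosen to be large enough for both components to be visible.

```r
# Sketch: total variance of the cell means = between + within variance.
# Settings are illustrative assumptions, not values from the analysis.
set.seed(7)
n_trt <- 2000; n_grp <- 50
mu <- 1; s <- 0.2; s_j <- 0.3

mu_j   <- rnorm(n_trt, mu, s)                              # treatment arm means
mu_j_k <- sapply(mu_j, function(m) rnorm(n_grp, m, s_j))   # n_grp x n_trt matrix

v_total   <- var(as.vector(mu_j_k))
v_between <- s^2
v_within  <- s_j^2
c(simulated = v_total, expected = v_between + v_within)
```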

In Figure 4, the prior is shown as the thin red line and the posterior as the solid black line. Clearly, something has been learnt about the variation from the data. Given that the data were simulated with s = 0 and s_j = 0.5, the within group variation (variation due to subgroups) is expected to be large relative to the between group variation (variation due to treatment).

Variance components (revision)
d_fig <- rbind(
  cbind(desc = "partial pool (trt)",
        data.table(f2$draws(variables = c("s_j"), format = "matrix"))),
  cbind(desc = "partial pool (trt+subgrp)",
        data.table(f3$draws(variables = c("s", "s_j"), format = "matrix"))  
  ), fill = T
)
d_fig <- melt(d_fig, id.vars = "desc")
d_fig[variable == "s", label := "variation b/w"]
d_fig[variable == "s_j", label := "variation w/in"]

d_pri <- CJ(
  variable = c("s", "s_j"),
  desc = c("partial pool (trt)", "partial pool (trt+subgrp)"),
  x = seq(min(d_fig$value, na.rm = T), 
          max(d_fig$value, na.rm = T), len = 500)
)
d_pri[, y := fGarch::dstd(x, nu = 3, mean = 0, sd = 2)]
d_pri[variable == "s", label := "variation b/w"]
d_pri[variable == "s_j", label := "variation w/in"]
d_pri[desc == "partial pool (trt)" & variable == "s", y := NA]
# 
# d_smry <- d_fig[, .(mu = mean(value)), keyby = .(label)]

p_fig <- ggplot(d_fig, aes(x = value, group = variable)) + 
  geom_density() +
  geom_line(
    data = d_pri, 
    aes(x = x, y = y), col = 2, lwd = 0.2
  ) +
  # geom_vline(data = d_smry,  
  #            aes(xintercept = mu), col = 2, lwd = 0.3) +
  # stat_function(fun = fGarch::dstd, 
  #               args = list(nu = 3, mean = 0, sd = 1), 
  #               col = 3, lty = 2) +
  scale_x_continuous("Standard deviation") +
  scale_y_continuous("Density") +
  facet_wrap(desc+label~., ncol = 2)

suppressWarnings(print(p_fig))
Figure 4: Between and within SD

There is a lot of within group variation and little between group variation. The implications remain to be discussed.

Code
set.seed(1)
par <- get_par(n_grp = 2, n_trt = 2, mu = 1, s = 0.0, s_j = 0.5)
d <- get_data(N = 3000, par)
Fit model to simulated data
# mle - reference point
d_lm <- copy(d)
d_lm[, `:=`(j = factor(j), k = factor(k))]
f0 <- glm(y ~ j*k, data = d_lm, family = binomial)
X <- model.matrix(f0)
# CI
n_sim <- 1000
d_lm_j <- matrix(NA, nrow = par$n_trt, ncol = n_sim)
d_lm_j_k <- matrix(NA, nrow = par$n_trt * par$n_grp, ncol = n_sim)
for(i in 1:n_sim){
  ix <- sort(sample(1:nrow(d_lm), replace = T))
  f_boot <- glm(y ~ j*k, data = d_lm[ix], family = binomial)
  d_tmp_p <- cbind(
    d_lm[ix],
    mu = predict(f_boot)
  )
  d_lm_j[, i] <- d_tmp_p[, .(mean = mean(mu)), keyby = j][, mean]
  d_lm_j_k[, i] <- d_tmp_p[, .(mean = mean(mu)), keyby = .(k, j)][, mean]
}
# bootstrapped intervals for means on j and within group means
d_lm_j <- data.table(d_lm_j)
d_lm_j[, j := 1:.N]
d_lm_j <- melt(d_lm_j, id.vars = "j")
d_lm_j[, j := factor(j)]
# d_lm_j[, .(q5 = quantile(value, prob =0.05), 
#            q95 = quantile(value, prob = 0.95)), keyby = j]

d_lm_j_k <- data.table(d_lm_j_k)
d_lm_j_k <- cbind(CJ(k = 1:par$n_grp, j = 1:par$n_trt), d_lm_j_k)
d_lm_j_k <- melt(d_lm_j_k, id.vars = c("j", "k"))
d_lm_j_k[, `:=`(j = factor(j), k = factor(k))]
# d_lm_j_k[, .(q5 = quantile(value, prob =0.05), 
#            q95 = quantile(value, prob = 0.95)), keyby = .(k,j)]

d_lm[, eta_hat := predict(f0)]

# bayes
m1 <- cmdstanr::cmdstan_model("stan/mlm-ex-01.stan")
m2 <- cmdstanr::cmdstan_model("stan/mlm-ex-02.stan")
m3 <- cmdstanr::cmdstan_model("stan/mlm-ex-03.stan")

ld <- list(
  N = nrow(d), 
  y = d$y, 
  J = length(unique(d$j)), K = length(unique(d$k)),
  j = d$j, # intervention
  k = d$k, # subgroup
  P = ncol(X),
  X = X,
  s = 3,
  s_j = 3, # for the indep means offsets
  prior_only = 0
)

f1 <- m1$sample(
    ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 1.4 seconds.
Chain 2 finished in 1.5 seconds.

Both chains finished successfully.
Mean chain execution time: 1.4 seconds.
Total execution time: 1.5 seconds.
Fit model to simulated data
f2 <- m2$sample(
    ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 1 finished in 6.1 seconds.
Chain 2 finished in 6.6 seconds.

Both chains finished successfully.
Mean chain execution time: 6.4 seconds.
Total execution time: 6.7 seconds.
Warning: 4 of 2000 (0.0%) transitions ended with a divergence.
See https://mc-stan.org/misc/warnings for details.
Fit model to simulated data
f3 <- m3$sample(
    ld, iter_warmup = 1000, iter_sampling = 1000,
    parallel_chains = 2, chains = 2, refresh = 0, show_exceptions = F,
    max_treedepth = 10)
Running MCMC with 2 parallel chains...

Chain 2 finished in 9.6 seconds.
Chain 1 finished in 15.5 seconds.

Both chains finished successfully.
Mean chain execution time: 12.6 seconds.
Total execution time: 15.5 seconds.
Warning: 25 of 2000 (1.0%) transitions ended with a divergence.
See https://mc-stan.org/misc/warnings for details.
Parameter estimates vs true values
d_1 <- data.table(t(f1$draws(variables = "eta", format = "matrix")))
d_1 <- cbind(d, d_1)
d_1 <- melt(d_1, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_2 <- data.table(t(f2$draws(variables = "eta", format = "matrix")))
d_2 <- cbind(d, d_2)
d_2 <- melt(d_2, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_3 <- data.table(t(f3$draws(variables = "eta", format = "matrix")))
d_3 <- cbind(d, d_3)
d_3 <- melt(d_3, id.vars = c("j", "k", "eta", "y"), variable.name = "i_draw")

d_mu_j <- rbind(
  d_1[, .(mu = mean(value)), keyby = .(j, i_draw)][
  , .(desc = "no pooling", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = j],
  d_2[, .(mu = mean(value)), keyby = .(j, i_draw)][
  , .(desc = "partial pool (trt)", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = j],
  d_3[, .(mu = mean(value)), keyby = .(j, i_draw)][
  , .(desc = "partial pool (trt+subgrp)", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = j]
)

  

# MLE
d_mle <- cbind(
  desc = "mle", 
  d_lm[, .(mean = mean(eta_hat)), keyby = .(j)])
d_mle <- merge(
  d_mle, 
  d_lm_j[, .(q5 = quantile(value, prob =0.05),
           q95 = quantile(value, prob = 0.95)), keyby = j], 
  by = "j"
)

d_mu_j <- rbind(d_mu_j, d_mle, fill = T)

# Observed
# Due to the imbalance between the subgroups, need to weight the
# contributions to align (approx) with mle.
# Easy here because only one other variable that we need to average
# over.
d_obs <- merge(
  d[, .(N_j = .N), keyby = .(j)],
  d[, .(mu = qlogis(mean(y)), .N), keyby = .(j, k)],
  by = "j"
)
d_mu_j <- rbind(
  d_mu_j,
  d_obs[, .(
    desc = "observed", 
    mean = sum(mu * N / N_j)), keyby = j]
  , fill = T
  )


d_mu_j[, desc := factor(desc, levels = c(
  "observed", "mle", "no pooling", "partial pool (trt)", "partial pool (trt+subgrp)"
))]



d_mu <- d[, .(mu = qlogis(mean(y)), 
         w = .N/nrow(d)), keyby = .(j, k)][
           , .(desc = "observed", mean = sum(mu * w))]


p_fig <- ggplot(d_mu_j, aes(x = j, y = mean, group = desc, col = desc)) +
  geom_point(position = position_dodge(width = 0.6)) +
  geom_linerange(aes(ymin=q5, ymax=q95), position = position_dodge2(width = 0.6)) +
  geom_hline(data = d_mu,
             aes(yintercept = mean, lty = desc),
             lwd = 0.3, col = 1) +
  scale_x_discrete("Treatment arm") +
  scale_y_continuous("log-odds treatment success", breaks = seq(-3, 3, by = 0.2))  +
  scale_color_discrete("") +
  scale_linetype_discrete("") +
  geom_text(data = d[, .N, keyby = .(j)],
            aes(x = j, y = min(d_mu_j$q5, na.rm = T) - 0.1, 
                label = N), inherit.aes = F) +
  theme(legend.position="bottom", legend.box="vertical", legend.margin=margin())

suppressWarnings(print(p_fig))  
Figure 5: Parameter estimates vs true values
Parameter estimates vs true values
# Manually calculate intervals for independent model
d_mu_j_k <- rbind(
  # d_mu_j_k,
  d_1[, .(mu = mean(value)), keyby = .(k, j, i_draw)][
  , .(desc = "no pooling", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = .(k, j)],
  d_2[, .(mu = mean(value)), keyby = .(k, j, i_draw)][
  , .(desc = "partial pool (trt)", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = .(k, j)],
  d_3[, .(mu = mean(value)), keyby = .(k, j, i_draw)][
  , .(desc = "partial pool (trt+subgrp)", 
      mean = mean(mu), 
      q5 = quantile(mu, prob = 0.05), 
      q95 = quantile(mu, prob = 0.95)), keyby = .(k, j)]
  )

# MLE
d_mle <- cbind(
  desc = "mle", 
  d_lm[, .(mean = mean(eta_hat)), keyby = .(k, j)])
d_mle <- merge(
  d_mle, 
  d_lm_j_k[, .(q5 = quantile(value, prob =0.05),
               q95 = quantile(value, prob = 0.95)), keyby = .(k,j)], 
  by = c("j","k")
)

d_mu_j_k <- rbind(d_mu_j_k, d_mle, fill = T)

# Observed
d_mu_j_k <- rbind(
  d_mu_j_k, 
  d[, .(desc = "observed", mean = qlogis(mean(y))), keyby = .(k, j)],
  fill = T)


d_mu_j_k[, desc := factor(desc, levels = c(
  "observed", "mle", "no pooling", "partial pool (trt)", "partial pool (trt+subgrp)"
))]

# overall mean: observed and modelled (model 3)
d_mu_obs <- d[, .(mu = qlogis(mean(y)))]
d_mu_mod <- data.table(f3$summary(variables = c(
    "mu"
    )))[, .(model = "partial pool (trt+subgrp)", variable, mean, q5, q95)]

p_fig <- ggplot(d_mu_j_k, aes(x = k, y = mean, group = desc, col = desc)) +
  geom_point(position = position_dodge(width = 0.6)) +
  geom_linerange(aes(ymin=q5, ymax=q95), position = position_dodge2(width = 0.6)) +
  scale_x_discrete("Subgroup") +
  scale_y_continuous("log-odds treatment success", breaks = seq(-3, 3, by = 0.2)) +
  geom_hline(data = d_mu_j[desc == "observed"],
             aes(yintercept = mean, group = desc),
             lwd = 0.3, col = 1) +
  geom_text(data = d[, .N, keyby = .(j, k)],
            aes(x = k, y = min(d_mu_j_k$q5, na.rm = T) - 0.1, label = N), 
            inherit.aes = F) +
  facet_wrap(~paste0("Treatment ", j)) 

suppressWarnings(print(p_fig))
Figure 6: Parameter estimates vs true values
Variance components (revision)
d_fig <- rbind(
  cbind(desc = "partial pool (trt)",
        data.table(f2$draws(variables = c("s_j"), format = "matrix"))),
  cbind(desc = "partial pool (trt+subgrp)",
        data.table(f3$draws(variables = c("s", "s_j"), format = "matrix"))  
  ), fill = T
)
d_fig <- melt(d_fig, id.vars = "desc")
d_fig[variable == "s", label := "variation b/w"]
d_fig[variable == "s_j", label := "variation w/in"]

d_pri <- CJ(
  variable = c("s", "s_j"),
  desc = c("partial pool (trt)", "partial pool (trt+subgrp)"),
  x = seq(min(d_fig$value, na.rm = T), 
          max(d_fig$value, na.rm = T), len = 500)
)
d_pri[, y := fGarch::dstd(x, nu = 3, mean = 0, sd = 2)]
d_pri[variable == "s", label := "variation b/w"]
d_pri[variable == "s_j", label := "variation w/in"]
d_pri[desc == "partial pool (trt)" & variable == "s", y := NA]
# 
# d_smry <- d_fig[, .(mu = mean(value)), keyby = .(label)]

p_fig <- ggplot(d_fig, aes(x = value, group = variable)) + 
  geom_density() +
  geom_line(
    data = d_pri, 
    aes(x = x, y = y), col = 2, lwd = 0.2
  ) +
  # geom_vline(data = d_smry,  
  #            aes(xintercept = mu), col = 2, lwd = 0.3) +
  # stat_function(fun = fGarch::dstd, 
  #               args = list(nu = 3, mean = 0, sd = 1), 
  #               col = 3, lty = 2) +
  scale_x_continuous("Standard deviation") +
  scale_y_continuous("Density") +
  facet_wrap(desc+label~., ncol = 2)

suppressWarnings(print(p_fig))
Figure 7: Between and within SD